{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6652a472",
   "metadata": {},
   "source": [
    "# L4: Describe-and-Generate game 🖍️"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b21aa74",
   "metadata": {},
   "source": [
    "Load your HF API key and relevant Python libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92f16527-fcab-41e0-bce8-634fae58b1b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import io\n",
    "from IPython.display import Image, display, HTML\n",
    "from PIL import Image\n",
    "import base64 \n",
    "\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file\n",
    "hf_api_key = os.environ['HF_API_KEY']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac9eab7f-4013-42ec-b34e-b413b0f8adb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Helper function\n",
    "import requests, json\n",
    "\n",
    "#Here we are going to call multiple endpoints!\n",
    "def get_completion(inputs, parameters=None, ENDPOINT_URL=\"\"):\n",
    "    headers = {\n",
    "      \"Authorization\": f\"Bearer {hf_api_key}\",\n",
    "      \"Content-Type\": \"application/json\"\n",
    "    }   \n",
    "    data = { \"inputs\": inputs }\n",
    "    if parameters is not None:\n",
    "        data.update({\"parameters\": parameters})\n",
    "    response = requests.request(\"POST\",\n",
    "                                ENDPOINT_URL,\n",
    "                                headers=headers,\n",
    "                                data=json.dumps(data))\n",
    "    return json.loads(response.content.decode(\"utf-8\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52cb1837",
   "metadata": {},
   "outputs": [],
   "source": [
    "#text-to-image\n",
    "TTI_ENDPOINT = os.environ['HF_API_TTI_BASE']\n",
    "#image-to-text\n",
    "ITT_ENDPOINT = os.environ['HF_API_ITT_BASE']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3f9c4d4",
   "metadata": {},
   "source": [
    "## Building your game with `gr.Blocks()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76170d93-0de0-4359-8f13-3d1bcc41f86a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Bringing the functions from lessons 3 and 4!\n",
    "def image_to_base64_str(pil_image):\n",
    "    byte_arr = io.BytesIO()\n",
    "    pil_image.save(byte_arr, format='PNG')\n",
    "    byte_arr = byte_arr.getvalue()\n",
    "    return str(base64.b64encode(byte_arr).decode('utf-8'))\n",
    "\n",
    "def base64_to_pil(img_base64):\n",
    "    base64_decoded = base64.b64decode(img_base64)\n",
    "    byte_stream = io.BytesIO(base64_decoded)\n",
    "    pil_image = Image.open(byte_stream)\n",
    "    return pil_image\n",
    "\n",
    "def captioner(image):\n",
    "    base64_image = image_to_base64_str(image)\n",
    "    result = get_completion(base64_image, None, ITT_ENDPOINT)\n",
    "    return result[0]['generated_text']\n",
    "\n",
    "def generate(prompt):\n",
    "    output = get_completion(prompt, None, TTI_ENDPOINT)\n",
    "    result_image = base64_to_pil(output)\n",
    "    return result_image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd1f3485",
   "metadata": {},
   "source": [
    "### First attempt, just captioning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64e403a9-bfcb-47e8-b741-743d5f4980fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr \n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# Describe-and-Generate game 🖍️\")\n",
    "    image_upload = gr.Image(label=\"Your first image\",type=\"pil\")\n",
    "    btn_caption = gr.Button(\"Generate caption\")\n",
    "    caption = gr.Textbox(label=\"Generated caption\")\n",
    "    \n",
    "    btn_caption.click(fn=captioner, inputs=[image_upload], outputs=[caption])\n",
    "\n",
    "gr.close_all()\n",
    "demo.launch(share=True, server_port=int(os.environ['PORT1']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "966cd4e8",
   "metadata": {},
   "source": [
    "### Let's add generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5ea9a93-9f97-43f8-86ca-2e8e4ea9dda8",
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# Describe-and-Generate game 🖍️\")\n",
    "    image_upload = gr.Image(label=\"Your first image\",type=\"pil\")\n",
    "    btn_caption = gr.Button(\"Generate caption\")\n",
    "    caption = gr.Textbox(label=\"Generated caption\")\n",
    "    btn_image = gr.Button(\"Generate image\")\n",
    "    image_output = gr.Image(label=\"Generated Image\")\n",
    "    btn_caption.click(fn=captioner, inputs=[image_upload], outputs=[caption])\n",
    "    btn_image.click(fn=generate, inputs=[caption], outputs=[image_output])\n",
    "\n",
    "gr.close_all()\n",
    "demo.launch(share=True, server_port=int(os.environ['PORT2']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22da53d1",
   "metadata": {},
   "source": [
    "### Doing it all at once"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "024a4529-7f38-41fb-aab3-c38525b61abe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def caption_and_generate(image):\n",
    "    caption = captioner(image)\n",
    "    image = generate(caption)\n",
    "    return [caption, image]\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# Describe-and-Generate game 🖍️\")\n",
    "    image_upload = gr.Image(label=\"Your first image\",type=\"pil\")\n",
    "    btn_all = gr.Button(\"Caption and generate\")\n",
    "    caption = gr.Textbox(label=\"Generated caption\")\n",
    "    image_output = gr.Image(label=\"Generated Image\")\n",
    "\n",
    "    btn_all.click(fn=caption_and_generate, inputs=[image_upload], outputs=[caption, image_output])\n",
    "\n",
    "gr.close_all()\n",
    "demo.launch(share=True, server_port=int(os.environ['PORT3']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba8bd128-fea1-4c92-8363-cdfba40aed8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "gr.close_all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dafaaa3-e48a-4625-a324-934fc37eaa58",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
